import abc

import gym
import numpy as np
import metaworld
from metaworld.envs import reward_utils
from metaworld.envs.mujoco.sawyer_xyz.v2 import SawyerFaucetOpenEnvV2, SawyerReachEnvV2

from diffgro.environments.variant import Categorical
from diffgro.environments.metaworld.policies import *

##################################################


class MetaWorldEnv(gym.Env):
    """Gym wrapper around one MetaWorld MT50 task family with a sparse reward.

    The wrapper cycles through a fixed subset of the benchmark's training
    tasks for ``env_name`` (one task per ``reset``), clips actions to
    ``[-1, 1]``, replaces the native reward with a 0/1 success signal and
    ends the episode on success or after ``max_steps`` steps.

    Observation layout (as noted by the original author):
        hand position   : [0:3]
        hand gripper    : [3]
        object position : [4:7]
        goal position   : [-3:]
    """

    def __init__(self, env_name, max_steps: int = 500, seed: int = 777):
        """Build the underlying MetaWorld env and expose its spaces.

        Args:
            env_name: MT50 task name, used as key into ``train_classes``.
            max_steps: Episode length limit enforced by ``step``.
            seed: Seed handed to the MT50 benchmark constructor.
        """
        self.domain_name = "metaworld"
        self.env_name = env_name
        self.task_num = 1
        self.max_steps = max_steps
        self._seed = seed

        self.initialize(env_name)

        self.observation_space = self.env.observation_space
        self.action_space = gym.spaces.Box(
            self.env.action_space.low, self.env.action_space.high, dtype=np.float32
        )

    def reset(self, seed=None):
        """Start a new episode on the next task in the rotation.

        Args:
            seed: Optional seed; when given it is stored and forwarded to
                the wrapped env's ``seed`` method.

        Returns:
            The observation returned by the wrapped env's ``reset``.
        """
        if seed is not None:
            self._seed = seed
            self.env.seed(seed)

        # Round-robin over the task list: use the head, then move it to the tail.
        task = self.tasks[0]
        self.env.set_task(task)
        self.tasks = self.tasks[1:] + [task]

        self.timesteps = 0
        self.prev_action = np.zeros(self.action_space.shape)
        obs = self.env.reset()
        return obs

    def step(self, action):
        """Advance the wrapped env by one clipped action.

        Args:
            action: Action vector; values are clipped to ``[-1, 1]``.

        Returns:
            ``(obs, rew, done, info)`` where ``rew`` is 1 on success and 0
            otherwise, and ``info`` additionally carries ``speed``,
            ``force``, ``force_axis`` and ``energy`` computed from the
            first three action components.
        """
        action = np.clip(action, -1, 1)
        obs, rew, done, info = self.env.step(action)
        self.timesteps += 1

        # Action statistics over the first three components.
        info["speed"] = np.linalg.norm(action[:3] - self.prev_action[:3])
        info["force"] = np.linalg.norm(action[:3])
        info["force_axis"] = np.sqrt(np.square(action[:3]))
        info["energy"] = np.sum(np.abs(action[:3]))

        # Sparse reward: the wrapped env's reward is discarded.
        rew = 1 if info["success"] else 0

        # ``>=`` rather than ``==`` so that ``done`` stays True if the caller
        # keeps stepping after the time limit was reached.
        if info["success"] or self.timesteps >= self.max_steps:
            done = True

        self.prev_action = action
        return obs, rew, done, info

    def render(self, offscreen=True, camera_name="corner", resolution=(480, 480)):
        """Render through the wrapped env and return whatever it returns."""
        image = self.env.render(offscreen, camera_name, resolution)
        return image

    def initialize(self, env_name):
        """Instantiate the MT50 env class and select the task subset.

        Sets ``self.env`` and ``self.tasks``. The slice of training tasks
        kept depends on ``env_name``; all other names keep the first six.
        """
        benchmark = metaworld.MT50(seed=self._seed)

        self.env = benchmark.train_classes[env_name]()
        tasks = [task for task in benchmark.train_tasks if task.env_name == env_name]
        if env_name in ["reach-wall-v2", "plate-slide-v2", "plate-slide-back-v2"]:
            tasks = tasks[0:1]
        elif env_name == "button-press-v2":
            tasks = tasks[-8:-5] + tasks[-3:]
        elif env_name == "drawer-close-v2":
            tasks = tasks[-2:]
        else:
            tasks = tasks[:6]
        self.tasks = tasks
        print(f"{env_name} Num T: {len(self.tasks)} {len(self.tasks)}")

    def get_exp(self):
        """Return the scripted policy for this task (no variant applied)."""
        exp = make_metaworld_policy(self.env_name, variant=None)
        return exp


class MetaWorldVariantEnv(MetaWorldEnv):
    """MetaWorld environment configured by a per-episode ``variant`` mapping.

    A variant is either passed to ``reset`` or sampled from
    ``variant_space``. It must provide ``arm_speed``, ``goal_resistance``,
    ``wind_xspeed`` and ``wind_yspeed``. ``goal_resistance`` drives two
    mechanisms in ``step``: joint damping for explicit "push"/"pull"
    skills, and a force gate ("damping") that is switched on once the
    tracked object starts to move.
    """

    # Force thresholds indexed by ``variant["goal_resistance"]``.
    # ``_set_goal_resistance`` shadows this list on the instance with the
    # selected scalar.
    goal_threshold = [0, 0, 0]
    # Name of the joint whose damping is switched for "push"/"pull" skills.
    joint_name = None

    # env_name -> (index into obs[4:7], True if resistance starts when that
    # coordinate rises above its initial value, False if when it drops below).
    _RESIST_TRIGGERS = {
        "button-press-v2": (1, True),
        "drawer-close-v2": (1, True),
        "faucet-open-v2": (1, True),
        "door-open-v2": (1, False),
        "window-close-v2": (0, False),
    }

    def __init__(self, variant_space, *args, **kwargs):
        """Create the env; remaining arguments go to ``MetaWorldEnv``.

        Args:
            variant_space: Object with ``sample()`` and a
                ``variant_config`` mapping, used to draw variants.
        """
        super().__init__(*args, **kwargs)
        self.variant_space = variant_space
        self.variant = None

        self.threshold = 1e9
        self.joint_id = None
        # ``step`` reads this flag, so it must exist before the first reset.
        self.damping = False

    def reset(self, *args, variant=None, warmup=False, **kwargs):
        """Reset the episode and apply a variant.

        Args:
            variant: Variant mapping to use; sampled from
                ``variant_space`` when ``None``.
            warmup: When true, run ``_warmup`` before returning.

        Returns:
            The observation read from the wrapped env after the optional
            warm-up.
        """
        super().reset(*args, **kwargs)

        self.variant = variant if variant is not None else self.variant_space.sample()
        self.reset_model()

        # Clear the flag *before* warm-up: ``_warmup`` calls ``step``, which
        # would otherwise hit a missing attribute on the first reset or
        # reuse the previous episode's state on later ones.
        self.damping = False
        if warmup:
            self._warmup()

        obs = self.env._get_obs()

        # Warm-up steps may have switched damping on; start the episode clean.
        self.damping = False
        return obs

    def _damping_threshold(self):
        """Return the force threshold for the damping gate, or ``None``.

        Thresholds are only defined for "drawer-close-v2" and
        "faucet-open-v2" at resistance levels 1 and 2.
        """
        if self.env_name not in ("drawer-close-v2", "faucet-open-v2"):
            return None
        resistance = self.variant["goal_resistance"]
        if resistance == 1:
            return 0.33
        if resistance == 2:
            return 0.36
        return None

    def step(self, action):
        """Apply resistance logic, then step the parent environment.

        Args:
            action: Either an action vector or an ``(action, skill)``
                tuple, where ``skill`` may be ``"push"`` or ``"pull"``.

        Returns:
            ``(obs, rew, done, info)`` from ``MetaWorldEnv.step`` with
            ``info["damping"]`` added, plus ``info["damping_rew"]`` and
            ``info["damping_rew_com"]`` when the damping gate was evaluated.
        """
        skill = None
        if isinstance(action, tuple):
            action, skill = action

        # NOTE: arm_speed / wind_xspeed / wind_yspeed are stored on the
        # instance but are not applied to the action here.

        # [resistance] explicit push/pull skills toggle the joint damping.
        if skill in ["push", "pull"]:
            force = np.linalg.norm(action[:3])
            joint_id = self.env.sim.model.joint_name2id(self.__class__.joint_name)
            if force < self.goal_threshold:
                self.env.sim.model.dof_damping[joint_id] = 100000
            else:
                # Rescale a copy so the caller's array is left untouched, and
                # skip the rescale for a zero vector (would divide by zero).
                action = np.array(action)
                if force > 0:
                    action[:3] = action[:3] / force * 0.3
                self.env.sim.model.dof_damping[joint_id] = 1

        damping_rew, damping_rew_com = None, None
        if self.damping:
            threshold = self._damping_threshold()
            # No threshold is defined for some env / resistance combinations;
            # the gate is skipped there instead of failing on an unbound name.
            if threshold is not None:
                force = np.linalg.norm(action[:3])
                if force < threshold:
                    action = action * 0.0  # no action applied
                    damping_rew = 0.0
                    damping_rew_com = 1 - (force - threshold) ** 2
                else:
                    damping_rew = 1.0
                    damping_rew_com = 1.0
                print(self.timesteps, threshold, force)

        # environment step
        obs, rew, done, info = super().step(action)

        trigger = self._RESIST_TRIGGERS.get(self.env_name)

        # Remember the object position seen on the first step of the episode.
        if self.timesteps < 2 and trigger is not None:
            self.obj_pos = obs[4:7]

        # [check when to resist] switch the gate on once the object has moved
        # in the task-specific direction.
        if (
            (skill is None)
            and (not self.damping)
            and (self.variant["goal_resistance"] >= 1)
        ):
            if trigger is not None:
                axis, rising = trigger
                current = obs[4:7][axis]
                initial = self.obj_pos[axis]
                moved = current > initial if rising else current < initial
                if moved:
                    print(f"Changing Dapming at {self.timesteps}")
                    self.damping = True

        info["damping"] = self.damping
        if damping_rew is not None:
            info["damping_rew"] = damping_rew
            info["damping_rew_com"] = damping_rew_com
        return obs, rew, done, info

    def get_exp(self):
        """Return the scripted policy built for the current variant."""
        exp = make_metaworld_policy(self.env_name, variant=self.variant)
        return exp

    def reset_model(self):
        """Copy the shared variant settings onto the instance.

        Raises:
            ValueError: If no variant has been set.
        """
        if self.variant is None:
            raise ValueError("variant must be set before reset_model() is called")
        self._set_arm_speed(self.variant["arm_speed"])
        self._set_goal_resistance(self.variant["goal_resistance"])
        self._set_wind_xspeed(self.variant["wind_xspeed"])
        self._set_wind_yspeed(self.variant["wind_yspeed"])

    def update_variant_space(self, variant_space):
        """Overlay ``variant_space.variant_config`` onto the current config."""
        for k, v in variant_space.variant_config.items():
            self.variant_space.variant_config[k] = v

    def _set_arm_speed(self, arm_speed):
        """Store the arm speed."""
        self.arm_speed = arm_speed

    def _set_wind_xspeed(self, wind_xspeed):
        """Store the x-axis wind speed."""
        self.wind_xspeed = wind_xspeed

    def _set_wind_yspeed(self, wind_yspeed):
        """Store the y-axis wind speed."""
        self.wind_yspeed = wind_yspeed

    def _set_goal_position(self, goal_pos):
        """Hook for subclasses; does nothing here."""
        pass

    def _set_goal_resistance(self, goal_resistance):
        """Select this class's force threshold for the resistance level."""
        self.goal_threshold = self.__class__.goal_threshold[goal_resistance]

    def _warmup(self):
        """Take 3 small random steps and 10 zero steps, then zero the counter.

        The random action is drawn from NumPy's global RNG and its last
        component is forced to 0.
        """
        action = np.random.normal(0.0, 0.1, size=self.action_space.shape)
        action[-1] = 0

        for _ in range(3):
            self.step(action)
        for _ in range(10):
            self.step(np.zeros_like(action))

        self.timesteps = 0


######################################################################


class DoorOpenVariantEnv(MetaWorldVariantEnv):
    """Door-open variant env that offsets the "handle" body per variant."""

    goal_threshold = [0.2, 0.5, 0.8]
    joint_name = "doorjoint"

    def reset_model(self):
        """Apply the shared variant settings, then the handle offset."""
        super().reset_model()

        if self.variant is not None:
            self._set_handle_position(self.variant["handle_position"])

    def _set_handle_position(self, handle_position):
        """Place the handle at its base position plus ``handle_position``.

        The offset is written into the simulator's model, which this code
        never restores, so adding it in place would stack the offset on
        every reset. The position found on the first call is cached and
        used as the base from then on.
        """
        handle_index = self.env.sim.model.body_name2id("handle")
        base_pos = getattr(self, "_handle_base_pos", None)
        if base_pos is None:
            base_pos = np.array(self.env.sim.model.body_pos[handle_index])
            self._handle_base_pos = base_pos
        self.env.sim.model.body_pos[handle_index] = base_pos + handle_position


class ButtonPressVariantEnv(MetaWorldVariantEnv):
    """Button-press variant env.

    Reads ``button_depth`` from the active variant on every model reset;
    the depth hook itself currently does nothing.
    """

    goal_threshold = [0.2, 0.5, 0.8]
    joint_name = "btnbox_joint"

    def reset_model(self):
        """Apply the shared variant settings, then the button depth."""
        super().reset_model()

        variant = self.variant
        if variant is None:
            return
        self._set_button_depth(variant["button_depth"])

    def _set_button_depth(self, button_depth):
        """Placeholder: the requested depth is accepted but not applied."""
        return None


class DrawerVariantEnv(MetaWorldVariantEnv):
    """Drawer variant env; the handle-position setting is currently ignored."""

    goal_threshold = [0.2, 0.5, 0.8]
    joint_name = "goal_slidey"

    def reset_model(self):
        """Apply the shared variant settings, then the handle-position hook."""
        super().reset_model()

        if self.variant is not None:
            self._set_handle_position(self.variant["handle_position"])

    def _set_handle_position(self, handle_position):
        """Intentionally a no-op.

        The method used to return before any statement that moved the
        "handle" body, so that code never ran; it has been removed. The
        variant must still provide ``handle_position`` because
        ``reset_model`` reads it.
        """
        return


class WindowVariantEnv(MetaWorldVariantEnv):
    """Window variant env that positions the "handle" body per variant."""

    goal_threshold = [0.2, 0.5, 0.8]
    joint_name = "window_slide"

    def reset_model(self):
        """Apply the shared variant settings, then the handle position."""
        super().reset_model()

        variant = self.variant
        if variant is None:
            return
        self._set_handle_position(variant["handle_position"])

    def _set_handle_position(self, handle_position):
        """Overwrite the model position of the "handle" body."""
        model = self.env.sim.model
        body_id = model.body_name2id("handle")
        model.body_pos[body_id] = handle_position


######################################################################


class FaucetOpenVariantEnv(MetaWorldVariantEnv):
    """Faucet-open variant env whose starting angle comes from the variant."""

    goal_threshold = [0.2, 0.5, 0.8]
    joint_name = "knob_Joint_1"

    def reset_model(self):
        """Apply the shared variant settings, then the initial faucet angle."""
        super().reset_model()

        variant = self.variant
        if variant is None:
            return
        self._set_initial_angle(variant["initial_angle"])

    def _set_initial_angle(self, initial_angle):
        """Rotate the "faucet_link2" body and move the target to match.

        The quaternion written is ``[cos(a/2), 0, 0, sin(a/2)]`` —
        presumably a rotation about the z axis in (w, x, y, z) order; confirm
        against the simulator's convention. The target is placed relative to
        ``obj_init_pos`` at radius 0.175 in the xy-plane and height 0.125.
        """
        model = self.env.sim.model
        link_id = model.body_name2id("faucet_link2")

        half_angle = initial_angle / 2
        rotation = np.array([np.cos(half_angle), 0, 0, np.sin(half_angle)])
        model.body_quat[link_id] = rotation

        target_offset = np.array(
            [np.cos(initial_angle) * 0.175, np.sin(initial_angle) * 0.175, 0.125]
        )
        self.env._target_pos = self.env.obj_init_pos + target_offset


class ReachVariantEnv(MetaWorldVariantEnv):
    """Reach variant env with an optional wall controlled by the variant."""

    def reset_model(self):
        """Apply the shared variant settings, then the wall setting."""
        super().reset_model()

        variant = self.variant
        if variant is None:
            return
        self._set_wall_existence(variant["wall_existence"])

    def _set_wall_existence(self, wall_existence):
        """Move the "wall" body to one of two fixed positions.

        A truthy flag puts it at ``[0.1, 0.75, 0.06]``; a falsy flag puts it
        at ``[0, 0, -1]``.
        """
        model = self.env.sim.model
        wall_id = model.body_name2id("wall")
        model.body_pos[wall_id] = [0.1, 0.75, 0.06] if wall_existence else [0, 0, -1]


class PushVariantEnv(MetaWorldVariantEnv):
    """Push variant env with an optional wall controlled by the variant."""

    def reset_model(self):
        """Apply the shared variant settings, then the wall setting."""
        super().reset_model()

        variant = self.variant
        if variant is None:
            return
        self._set_wall_existence(variant["wall_existence"])

    def _set_wall_existence(self, wall_existence):
        """Move the "wall" body to one of two fixed positions.

        A truthy flag puts it at ``[0.1, 0.75, 0.06]``; a falsy flag puts it
        at ``[0, 0, -1]``.
        """
        model = self.env.sim.model
        wall_id = model.body_name2id("wall")
        model.body_pos[wall_id] = [0.1, 0.75, 0.06] if wall_existence else [0, 0, -1]


class PegInsertSideVariantEnv(MetaWorldVariantEnv):
    """Peg-insert-side variant env whose peg geometry comes from the variant."""

    def reset_model(self):
        """Apply the shared variant settings, then the peg shape."""
        super().reset_model()

        variant = self.variant
        if variant is None:
            return
        self._set_peg_shape(variant["peg_shape"])

    def _set_peg_shape(self, peg_shape):
        """Overwrite the model size entry of the "peg" geom."""
        model = self.env.sim.model
        geom_id = model.geom_name2id("peg")
        model.geom_size[geom_id] = peg_shape


####################################################################################


class PushBackVariantEnv(MetaWorldVariantEnv):
    """Variant env with no overrides; behaves exactly like its base class."""


class PlateSlideVariantEnv(MetaWorldVariantEnv):
    """Variant env with no overrides; behaves exactly like its base class."""


class PlateSlideSideVariantEnv(MetaWorldVariantEnv):
    """Variant env with no overrides; behaves exactly like its base class."""


class PlateSlideBackVariantEnv(MetaWorldVariantEnv):
    """Variant env with no overrides; behaves exactly like its base class."""


class PlateSlideBackSideVariantEnv(MetaWorldVariantEnv):
    """Variant env with no overrides; behaves exactly like its base class."""


class PegUnplugSideVariantEnv(MetaWorldVariantEnv):
    """Variant env with no overrides; behaves exactly like its base class."""


class PickPlaceVariantEnv(MetaWorldVariantEnv):
    """Pick-place variant env with an optional wall controlled by the variant."""

    def reset_model(self):
        """Apply the shared variant settings, then the wall setting."""
        super().reset_model()

        variant = self.variant
        if variant is None:
            return
        self._set_wall_existence(variant["wall_existence"])

    def _set_wall_existence(self, wall_existence):
        """Move the "wall" body to one of two fixed positions.

        A truthy flag puts it at ``[0.1, 0.75, 0.06]``; a falsy flag puts it
        at ``[0, 0, -1]``.
        """
        model = self.env.sim.model
        wall_id = model.body_name2id("wall")
        model.body_pos[wall_id] = [0.1, 0.75, 0.06] if wall_existence else [0, 0, -1]
